{
 "cells": [
  {
   "cell_type": "code",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-01-23T08:17:06.194795Z",
     "iopub.status.busy": "2025-01-23T08:17:06.194511Z",
     "iopub.status.idle": "2025-01-23T08:17:08.288087Z",
     "shell.execute_reply": "2025-01-23T08:17:08.287432Z",
     "shell.execute_reply.started": "2025-01-23T08:17:06.194772Z"
    },
    "ExecuteTime": {
     "end_time": "2025-03-01T09:06:11.735873Z",
     "start_time": "2025-03-01T09:06:11.728467Z"
    }
   },
   "source": [
    "import matplotlib as mpl\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "import numpy as np\n",
    "import sklearn\n",
    "import pandas as pd\n",
    "import os\n",
    "import sys\n",
    "import time\n",
    "from tqdm.auto import tqdm\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "# Log the interpreter and library versions so a run can be reproduced later.\n",
    "print(sys.version_info)\n",
    "for module in (mpl, np, pd, sklearn, torch):\n",
    "    print(module.__name__, module.__version__)\n",
    "\n",
    "# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "print(device)\n",
    "\n",
    "# Seed the CPU generator and every CUDA generator with the same value.\n",
    "seed = 42\n",
    "torch.manual_seed(seed)\n",
    "torch.cuda.manual_seed_all(seed)"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "sys.version_info(major=3, minor=12, micro=3, releaselevel='final', serial=0)\n",
      "matplotlib 3.10.0\n",
      "numpy 2.0.2\n",
      "pandas 2.2.3\n",
      "sklearn 1.6.1\n",
      "torch 2.6.0+cu126\n",
      "cuda:0\n"
     ]
    }
   ],
   "execution_count": 14
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 数据准备"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecutionIndicator": {
     "show": true
    },
    "execution": {
     "iopub.execute_input": "2025-01-23T08:17:08.289369Z",
     "iopub.status.busy": "2025-01-23T08:17:08.289041Z",
     "iopub.status.idle": "2025-01-23T08:17:08.292186Z",
     "shell.execute_reply": "2025-01-23T08:17:08.291580Z",
     "shell.execute_reply.started": "2025-01-23T08:17:08.289346Z"
    },
    "tags": [],
    "ExecuteTime": {
     "end_time": "2025-03-01T09:06:11.784882Z",
     "start_time": "2025-03-01T09:06:11.781881Z"
    }
   },
   "source": [
    "# !wget https://storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt"
   ],
   "outputs": [],
   "execution_count": 15
  },
  {
   "cell_type": "code",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-01-23T08:17:08.292951Z",
     "iopub.status.busy": "2025-01-23T08:17:08.292769Z",
     "iopub.status.idle": "2025-01-23T08:17:08.297467Z",
     "shell.execute_reply": "2025-01-23T08:17:08.296818Z",
     "shell.execute_reply.started": "2025-01-23T08:17:08.292931Z"
    },
    "tags": [],
    "ExecuteTime": {
     "end_time": "2025-03-01T09:06:11.789679Z",
     "start_time": "2025-03-01T09:06:11.785882Z"
    }
   },
   "source": [
    "# Corpus: https://storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt\n",
    "# The file is expected to be downloaded already, next to this notebook.\n",
    "with open(\"./shakespeare.txt\", mode=\"r\", encoding=\"utf8\") as file:\n",
    "    text = file.read()\n",
    "\n",
    "# Report the corpus size and preview its first 100 characters.\n",
    "print(\"length\", len(text))\n",
    "print(text[:100])"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "length 1115394\n",
      "First Citizen:\n",
      "Before we proceed any further, hear me speak.\n",
      "\n",
      "All:\n",
      "Speak, speak.\n",
      "\n",
      "First Citizen:\n",
      "You\n"
     ]
    }
   ],
   "execution_count": 16
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 构造字典"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-01-23T08:17:08.298036Z",
     "iopub.status.busy": "2025-01-23T08:17:08.297868Z",
     "iopub.status.idle": "2025-01-23T08:17:08.318373Z",
     "shell.execute_reply": "2025-01-23T08:17:08.317661Z",
     "shell.execute_reply.started": "2025-01-23T08:17:08.298017Z"
    },
    "ExecuteTime": {
     "end_time": "2025-03-01T09:06:11.806296Z",
     "start_time": "2025-03-01T09:06:11.798678Z"
    }
   },
   "source": [
    "# Preparation plan:\n",
    "# 1. generate the vocabulary\n",
    "# 2. build the char -> id mapping\n",
    "# 3. convert the whole corpus into ids\n",
    "# 4. a b c d [EOS] -> [BOS] b c d: next-character prediction, i.e. the input a is paired with the output b\n",
    "\n",
    "# Keep the distinct characters only, then sort them in place (sorting is mainly for readability).\n",
    "vocab = list(set(text))\n",
    "vocab.sort()\n",
    "print(len(vocab))\n",
    "print(vocab)"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "65\n",
      "['\\n', ' ', '!', '$', '&', \"'\", ',', '-', '.', '3', ':', ';', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']\n"
     ]
    }
   ],
   "execution_count": 17
  },
  {
   "cell_type": "code",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-01-23T08:17:08.320835Z",
     "iopub.status.busy": "2025-01-23T08:17:08.320398Z",
     "iopub.status.idle": "2025-01-23T08:17:08.324538Z",
     "shell.execute_reply": "2025-01-23T08:17:08.323832Z",
     "shell.execute_reply.started": "2025-01-23T08:17:08.320803Z"
    },
    "ExecuteTime": {
     "end_time": "2025-03-01T09:06:11.817545Z",
     "start_time": "2025-03-01T09:06:11.814294Z"
    }
   },
   "source": [
    "# Reminder of what enumerate yields: (position, item) pairs.\n",
    "for idx, char in enumerate(\"how\"):\n",
    "    print(idx, char)"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 h\n",
      "1 o\n",
      "2 w\n"
     ]
    }
   ],
   "execution_count": 18
  },
  {
   "cell_type": "code",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-01-23T08:17:08.325508Z",
     "iopub.status.busy": "2025-01-23T08:17:08.325251Z",
     "iopub.status.idle": "2025-01-23T08:17:08.328898Z",
     "shell.execute_reply": "2025-01-23T08:17:08.328318Z",
     "shell.execute_reply.started": "2025-01-23T08:17:08.325470Z"
    },
    "ExecuteTime": {
     "end_time": "2025-03-01T09:06:11.848808Z",
     "start_time": "2025-03-01T09:06:11.845549Z"
    }
   },
   "source": [
    "# Give every vocabulary character an integer id equal to its position in vocab.\n",
    "char2idx = dict(zip(vocab, range(len(vocab))))\n",
    "print(char2idx)"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'\\n': 0, ' ': 1, '!': 2, '$': 3, '&': 4, \"'\": 5, ',': 6, '-': 7, '.': 8, '3': 9, ':': 10, ';': 11, '?': 12, 'A': 13, 'B': 14, 'C': 15, 'D': 16, 'E': 17, 'F': 18, 'G': 19, 'H': 20, 'I': 21, 'J': 22, 'K': 23, 'L': 24, 'M': 25, 'N': 26, 'O': 27, 'P': 28, 'Q': 29, 'R': 30, 'S': 31, 'T': 32, 'U': 33, 'V': 34, 'W': 35, 'X': 36, 'Y': 37, 'Z': 38, 'a': 39, 'b': 40, 'c': 41, 'd': 42, 'e': 43, 'f': 44, 'g': 45, 'h': 46, 'i': 47, 'j': 48, 'k': 49, 'l': 50, 'm': 51, 'n': 52, 'o': 53, 'p': 54, 'q': 55, 'r': 56, 's': 57, 't': 58, 'u': 59, 'v': 60, 'w': 61, 'x': 62, 'y': 63, 'z': 64}\n"
     ]
    }
   ],
   "execution_count": 19
  },
  {
   "cell_type": "code",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-01-23T08:17:08.330015Z",
     "iopub.status.busy": "2025-01-23T08:17:08.329574Z",
     "iopub.status.idle": "2025-01-23T08:17:08.333345Z",
     "shell.execute_reply": "2025-01-23T08:17:08.332744Z",
     "shell.execute_reply.started": "2025-01-23T08:17:08.329984Z"
    },
    "ExecuteTime": {
     "end_time": "2025-03-01T09:06:11.864809Z",
     "start_time": "2025-03-01T09:06:11.859808Z"
    }
   },
   "source": [
    "# Reverse lookup (id -> char): the vocabulary as an ndarray, indexable by id.\n",
    "idx2char = np.asarray(vocab)\n",
    "print(idx2char)"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['\\n' ' ' '!' '$' '&' \"'\" ',' '-' '.' '3' ':' ';' '?' 'A' 'B' 'C' 'D' 'E'\n",
      " 'F' 'G' 'H' 'I' 'J' 'K' 'L' 'M' 'N' 'O' 'P' 'Q' 'R' 'S' 'T' 'U' 'V' 'W'\n",
      " 'X' 'Y' 'Z' 'a' 'b' 'c' 'd' 'e' 'f' 'g' 'h' 'i' 'j' 'k' 'l' 'm' 'n' 'o'\n",
      " 'p' 'q' 'r' 's' 't' 'u' 'v' 'w' 'x' 'y' 'z']\n"
     ]
    }
   ],
   "execution_count": 20
  },
  {
   "cell_type": "code",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-01-23T08:17:08.334411Z",
     "iopub.status.busy": "2025-01-23T08:17:08.334017Z",
     "iopub.status.idle": "2025-01-23T08:17:08.451298Z",
     "shell.execute_reply": "2025-01-23T08:17:08.450748Z",
     "shell.execute_reply.started": "2025-01-23T08:17:08.334380Z"
    },
    "ExecuteTime": {
     "end_time": "2025-03-01T09:06:11.959025Z",
     "start_time": "2025-03-01T09:06:11.865809Z"
    }
   },
   "source": [
    "# Encode the whole corpus as an ndarray of character ids.\n",
    "text_as_int = np.array(list(map(char2idx.__getitem__, text)))\n",
    "print(text_as_int.shape)\n",
    "print(len(text_as_int))\n",
    "# The first ten ids next to the ten characters they encode.\n",
    "print(text_as_int[:10])\n",
    "print(text[:10])"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(1115394,)\n",
      "1115394\n",
      "[18 47 56 57 58  1 15 47 58 47]\n",
      "First Citi\n"
     ]
    }
   ],
   "execution_count": 21
  },
  {
   "cell_type": "code",
   "metadata": {
    "collapsed": false,
    "execution": {
     "iopub.execute_input": "2025-01-23T08:17:08.452389Z",
     "iopub.status.busy": "2025-01-23T08:17:08.451936Z",
     "iopub.status.idle": "2025-01-23T08:17:08.456739Z",
     "shell.execute_reply": "2025-01-23T08:17:08.456178Z",
     "shell.execute_reply.started": "2025-01-23T08:17:08.452357Z"
    },
    "jupyter": {
     "outputs_hidden": false
    },
    "ExecuteTime": {
     "end_time": "2025-03-01T09:06:11.963028Z",
     "start_time": "2025-03-01T09:06:11.960023Z"
    }
   },
   "source": [
    "# Number of complete samples: corpus length (1115394 ids) divided by the 101 ids each sample spans.\n",
    "1115394//101"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "11043"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 22
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 把莎士比亚文集分成一个一个的样本"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-01-23T08:17:08.457842Z",
     "iopub.status.busy": "2025-01-23T08:17:08.457436Z",
     "iopub.status.idle": "2025-01-23T08:17:08.464781Z",
     "shell.execute_reply": "2025-01-23T08:17:08.464218Z",
     "shell.execute_reply.started": "2025-01-23T08:17:08.457810Z"
    },
    "ExecuteTime": {
     "end_time": "2025-03-01T09:06:11.967960Z",
     "start_time": "2025-03-01T09:06:11.963539Z"
    }
   },
   "source": [
    "from torch.utils.data import Dataset, DataLoader\n",
    "\n",
    "class CharDataset(Dataset):\n",
    "    \"\"\"Non-overlapping, fixed-length windows over an encoded character stream.\n",
    "\n",
    "    Every sample holds seq_length + 1 consecutive ids, so a collate function can\n",
    "    split it into inputs (all but the last id) and targets (all but the first id).\n",
    "    Trailing ids that do not fill a complete window are ignored.\n",
    "\n",
    "    Args:\n",
    "        text_as_int: 1-D sequence of character ids supporting len() and slicing.\n",
    "        seq_length: number of input ids per sample.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, text_as_int, seq_length):\n",
    "        self.sub_len = seq_length + 1  # ids per sample: the inputs plus one shifted target\n",
    "        self.text_as_int = text_as_int\n",
    "        self.num_seq = len(text_as_int) // self.sub_len  # number of complete samples\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        \"\"\"Return sample `index` as a slice of sub_len ids (e.g. ids 0-100 for index 0).\n",
    "\n",
    "        Negative indices count from the end.\n",
    "\n",
    "        Raises:\n",
    "            IndexError: if index is outside [-len(self), len(self)). Without this,\n",
    "                slicing past the end silently returns empty samples and plain\n",
    "                iteration over the dataset never terminates.\n",
    "        \"\"\"\n",
    "        if index < 0:\n",
    "            index += self.num_seq\n",
    "        if not 0 <= index < self.num_seq:\n",
    "            raise IndexError(f\"sample index out of range for a dataset of {self.num_seq} samples\")\n",
    "        start = index * self.sub_len\n",
    "        return self.text_as_int[start: start + self.sub_len]\n",
    "\n",
    "    def __len__(self):\n",
    "        \"\"\"Return the number of complete samples.\"\"\"\n",
    "        return self.num_seq\n",
    "\n",
    "# batch is a list of samples; each sample holds 101 ids: the first 100 are the inputs, the last 100 the targets.\n",
    "def collat_fct(batch):\n",
    "    \"\"\"Collate samples into (inputs, targets) int64 tensors for next-character prediction.\n",
    "\n",
    "    Args:\n",
    "        batch: list of equal-length 1-D id sequences, one per sample.\n",
    "\n",
    "    Returns:\n",
    "        Tuple (src, trg) of torch.int64 tensors. src drops the last id of every sample\n",
    "        and trg drops the first, so trg[i, t] is the id that follows src[i, t].\n",
    "    \"\"\"\n",
    "    src_array = np.array([part[:-1] for part in batch])  # inputs\n",
    "    trg_array = np.array([part[1:] for part in batch])  # targets, shifted by one position\n",
    "    # Build int64 tensors directly. torch.Tensor(...) would first convert the ids to\n",
    "    # float32, which rounds values above 2**24 before the cast back to int64.\n",
    "    src = torch.as_tensor(src_array, dtype=torch.int64)\n",
    "    trg = torch.as_tensor(trg_array, dtype=torch.int64)\n",
    "    return src, trg\n",
    "        \n",
    "# Each sample spans 101 ids: 100 inputs plus one extra id so the targets can be shifted by one.\n",
    "train_ds = CharDataset(text_as_int, seq_length=100)\n",
    "train_dl = DataLoader(dataset=train_ds, batch_size=64, collate_fn=collat_fct, shuffle=True)"
   ],
   "outputs": [],
   "execution_count": 23
  },
  {
   "cell_type": "code",
   "metadata": {
    "collapsed": false,
    "execution": {
     "iopub.execute_input": "2025-01-23T08:17:08.465951Z",
     "iopub.status.busy": "2025-01-23T08:17:08.465498Z",
     "iopub.status.idle": "2025-01-23T08:17:08.471236Z",
     "shell.execute_reply": "2025-01-23T08:17:08.470746Z",
     "shell.execute_reply.started": "2025-01-23T08:17:08.465920Z"
    },
    "jupyter": {
     "outputs_hidden": false
    },
    "ExecuteTime": {
     "end_time": "2025-03-01T09:06:12.004966Z",
     "start_time": "2025-03-01T09:06:11.968957Z"
    }
   },
   "source": [
    "# Peek at a single batch to confirm the input and target shapes.\n",
    "for datas, labels in train_dl:\n",
    "    print(datas.shape, labels.shape, sep=\"\\n\")\n",
    "    break"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([64, 100])\n",
      "torch.Size([64, 100])\n"
     ]
    }
   ],
   "execution_count": 24
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 定义模型"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-01-23T08:17:08.472221Z",
     "iopub.status.busy": "2025-01-23T08:17:08.471840Z",
     "iopub.status.idle": "2025-01-23T08:17:08.487855Z",
     "shell.execute_reply": "2025-01-23T08:17:08.487375Z",
     "shell.execute_reply.started": "2025-01-23T08:17:08.472190Z"
    },
    "ExecuteTime": {
     "end_time": "2025-03-01T09:06:12.016963Z",
     "start_time": "2025-03-01T09:06:12.005966Z"
    }
   },
   "source": [
    "class CharRNN(nn.Module):\n",
    "    \"\"\"Character-level language model: embedding -> single-layer RNN -> linear projection.\"\"\"\n",
    "\n",
    "    def __init__(self, vocab_size, embedding_dim=256, hidden_dim=1024):\n",
    "        super().__init__()\n",
    "        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim)\n",
    "        # batch_first=True: the RNN consumes (batch_size, seq_len, embedding_dim).\n",
    "        self.rnn = nn.RNN(input_size=embedding_dim, hidden_size=hidden_dim, batch_first=True)\n",
    "        self.fc = nn.Linear(in_features=hidden_dim, out_features=vocab_size)\n",
    "\n",
    "    def forward(self, x, hidden=None):\n",
    "        \"\"\"Return per-position logits and the final hidden state.\n",
    "\n",
    "        Args:\n",
    "            x: integer ids of shape (batch_size, seq_len).\n",
    "            hidden: optional initial hidden state, passed straight to the RNN.\n",
    "\n",
    "        Returns:\n",
    "            (logits, hidden), with logits of shape (batch_size, seq_len, vocab_size).\n",
    "        \"\"\"\n",
    "        embedded = self.embedding(x)  # (batch_size, seq_len, embedding_dim), e.g. (64, 100, 256)\n",
    "        # Unlike the earlier 02 example, the outputs of all time steps are kept\n",
    "        # instead of only the last one: (batch_size, seq_len, hidden_dim), e.g. (64, 100, 1024).\n",
    "        states, hidden = self.rnn(embedded, hidden)\n",
    "        logits = self.fc(states)  # (batch_size, seq_len, vocab_size), e.g. (64, 100, 65)\n",
    "        return logits, hidden\n",
    "    \n",
    "    \n",
    "vocab_size = len(vocab)\n",
    "\n",
    "# Banner, then one line per parameter tensor with its element count.\n",
    "print(format(\" 一层单向 RNN \", \"=^80\"))\n",
    "for key, value in CharRNN(vocab_size).named_parameters():\n",
    "    print(f\"{key:^40}paramerters num: {value.numel()}\")"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "=================================== 一层单向 RNN ===================================\n",
      "            embedding.weight            paramerters num: 16640\n",
      "            rnn.weight_ih_l0            paramerters num: 262144\n",
      "            rnn.weight_hh_l0            paramerters num: 1048576\n",
      "             rnn.bias_ih_l0             paramerters num: 1024\n",
      "             rnn.bias_hh_l0             paramerters num: 1024\n",
      "               fc.weight                paramerters num: 66560\n",
      "                fc.bias                 paramerters num: 65\n"
     ]
    }
   ],
   "execution_count": 25
  },
  {
   "cell_type": "code",
   "metadata": {
    "collapsed": false,
    "execution": {
     "iopub.execute_input": "2025-01-23T08:17:08.488861Z",
     "iopub.status.busy": "2025-01-23T08:17:08.488421Z",
     "iopub.status.idle": "2025-01-23T08:17:08.493065Z",
     "shell.execute_reply": "2025-01-23T08:17:08.492525Z",
     "shell.execute_reply.started": "2025-01-23T08:17:08.488829Z"
    },
    "jupyter": {
     "outputs_hidden": false
    },
    "ExecuteTime": {
     "end_time": "2025-03-01T09:06:12.021430Z",
     "start_time": "2025-03-01T09:06:12.016963Z"
    }
   },
   "source": [
    "# Recompute the parameter counts reported for CharRNN from the layer sizes\n",
    "# (vocab_size=65, embedding_dim=256, hidden_dim=1024).\n",
    "print(65*256)  # embedding.weight: vocab_size * embedding_dim\n",
    "print(256*1024)  # rnn.weight_ih_l0: embedding_dim * hidden_dim\n",
    "print(1024*1024)  # rnn.weight_hh_l0: hidden_dim * hidden_dim\n",
    "1024*65  # fc.weight: hidden_dim * vocab_size"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "16640\n",
      "262144\n",
      "1048576\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "66560"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 26
  },
  {
   "cell_type": "code",
   "metadata": {
    "collapsed": false,
    "execution": {
     "iopub.execute_input": "2025-01-23T08:17:08.496358Z",
     "iopub.status.busy": "2025-01-23T08:17:08.496008Z",
     "iopub.status.idle": "2025-01-23T08:17:08.520922Z",
     "shell.execute_reply": "2025-01-23T08:17:08.520382Z",
     "shell.execute_reply.started": "2025-01-23T08:17:08.496327Z"
    },
    "jupyter": {
     "outputs_hidden": false
    },
    "ExecuteTime": {
     "end_time": "2025-03-01T09:06:12.050301Z",
     "start_time": "2025-03-01T09:06:12.021430Z"
    }
   },
   "source": [
    "# Smoke test: push a random batch of 3 sequences x 100 ids through a fresh model.\n",
    "sample_inputs = torch.randint(low=0, high=vocab_size, size=(3, 100))\n",
    "print(sample_inputs.size())\n",
    "model = CharRNN(vocab_size)\n",
    "output = model(sample_inputs)\n",
    "# output is the (logits, hidden) pair returned by forward; show the logits shape.\n",
    "output[0].shape"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([3, 100])\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "torch.Size([3, 100, 65])"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 27
  },
  {
   "cell_type": "code",
   "metadata": {
    "collapsed": false,
    "execution": {
     "iopub.execute_input": "2025-01-23T08:17:08.521739Z",
     "iopub.status.busy": "2025-01-23T08:17:08.521557Z",
     "iopub.status.idle": "2025-01-23T08:17:08.525010Z",
     "shell.execute_reply": "2025-01-23T08:17:08.524573Z",
     "shell.execute_reply.started": "2025-01-23T08:17:08.521720Z"
    },
    "jupyter": {
     "outputs_hidden": false
    },
    "ExecuteTime": {
     "end_time": "2025-03-01T09:06:12.054301Z",
     "start_time": "2025-03-01T09:06:12.051299Z"
    }
   },
   "source": [
    "# Flatten the (batch, seq_len, vocab_size) logits to one row per position, the layout the training loop feeds to the loss.\n",
    "output[0].reshape(-1, vocab_size).shape"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([300, 65])"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 28
  },
  {
   "cell_type": "code",
   "metadata": {
    "collapsed": false,
    "execution": {
     "iopub.execute_input": "2025-01-23T08:17:08.525644Z",
     "iopub.status.busy": "2025-01-23T08:17:08.525470Z",
     "iopub.status.idle": "2025-01-23T08:17:08.528873Z",
     "shell.execute_reply": "2025-01-23T08:17:08.528260Z",
     "shell.execute_reply.started": "2025-01-23T08:17:08.525625Z"
    },
    "jupyter": {
     "outputs_hidden": false
    },
    "ExecuteTime": {
     "end_time": "2025-03-01T09:06:12.058493Z",
     "start_time": "2025-03-01T09:06:12.055300Z"
    }
   },
   "source": [
    "\n",
    "# Size of fc.weight: hidden_dim (1024) * vocab_size (65).\n",
    "1024*65"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "66560"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 29
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 训练"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-01-23T08:17:08.529522Z",
     "iopub.status.busy": "2025-01-23T08:17:08.529347Z",
     "iopub.status.idle": "2025-01-23T08:17:08.534592Z",
     "shell.execute_reply": "2025-01-23T08:17:08.534155Z",
     "shell.execute_reply.started": "2025-01-23T08:17:08.529503Z"
    },
    "ExecuteTime": {
     "end_time": "2025-03-01T09:06:12.064093Z",
     "start_time": "2025-03-01T09:06:12.059490Z"
    }
   },
   "source": [
    "class SaveCheckpointsCallback:\n",
    "    \"\"\"Persist model weights on every save_step-th training step.\"\"\"\n",
    "\n",
    "    def __init__(self, save_dir, save_step=5000, save_best_only=True):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            save_dir (str): directory that receives the checkpoint files. It is created,\n",
    "                including missing parent directories, when it does not exist.\n",
    "            save_step (int, optional): a checkpoint is considered only on steps that are\n",
    "                multiples of this value. Defaults to 5000.\n",
    "            save_best_only (bool, optional): if True, keep a single best.ckpt that is\n",
    "                overwritten whenever the metric matches or beats the best seen so far;\n",
    "                if False, write {step}.ckpt on every eligible step.\n",
    "        \"\"\"\n",
    "        self.save_dir = save_dir\n",
    "        self.save_step = save_step\n",
    "        self.save_best_only = save_best_only\n",
    "        # Start from -inf so the first reported metric always becomes the best one.\n",
    "        # A fixed floor of -1 rejects negative-loss metrics until the loss falls below 1.\n",
    "        self.best_metrics = float(\"-inf\")\n",
    "\n",
    "        # makedirs handles nested paths and tolerates an already existing directory.\n",
    "        os.makedirs(self.save_dir, exist_ok=True)\n",
    "\n",
    "    def __call__(self, step, state_dict, metric=None):\n",
    "        \"\"\"Save state_dict when step is a multiple of save_step.\n",
    "\n",
    "        Args:\n",
    "            step (int): current global training step.\n",
    "            state_dict: object handed to torch.save.\n",
    "            metric (float, optional): higher-is-better score; required when\n",
    "                save_best_only is True.\n",
    "\n",
    "        Raises:\n",
    "            AssertionError: if save_best_only is True and metric is None.\n",
    "        \"\"\"\n",
    "        if step % self.save_step > 0:\n",
    "            return\n",
    "\n",
    "        if self.save_best_only:\n",
    "            assert metric is not None\n",
    "            if metric >= self.best_metrics:\n",
    "                # Overwrite the single best checkpoint, then remember the new best score.\n",
    "                torch.save(state_dict, os.path.join(self.save_dir, \"best.ckpt\"))\n",
    "                self.best_metrics = metric\n",
    "        else:\n",
    "            torch.save(state_dict, os.path.join(self.save_dir, f\"{step}.ckpt\"))\n",
    "\n"
   ],
   "outputs": [],
   "execution_count": 30
  },
  {
   "cell_type": "code",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-01-23T08:17:08.535220Z",
     "iopub.status.busy": "2025-01-23T08:17:08.535056Z",
     "iopub.status.idle": "2025-01-23T08:17:09.444746Z",
     "shell.execute_reply": "2025-01-23T08:17:09.444183Z",
     "shell.execute_reply.started": "2025-01-23T08:17:08.535201Z"
    },
    "ExecuteTime": {
     "end_time": "2025-03-01T09:06:14.074149Z",
     "start_time": "2025-03-01T09:06:12.080094Z"
    }
   },
   "source": [
    "# Training\n",
    "def _carry_hidden(hidden, batch_size):\n",
    "    \"\"\"Return the previous hidden state ready for reuse, or None if it cannot be reused.\n",
    "\n",
    "    The state is detached so gradients stop at the batch boundary; otherwise the next\n",
    "    backward() would try to walk the already-freed graph of the previous batch.\n",
    "    A state whose batch dimension differs from the incoming batch is dropped.\n",
    "    Assumes dim 1 of each state tensor is the batch dimension, as for torch.nn.RNN/GRU/LSTM.\n",
    "    \"\"\"\n",
    "    if hidden is None:\n",
    "        return None\n",
    "    states = hidden if isinstance(hidden, tuple) else (hidden,)\n",
    "    if any(state.size(1) != batch_size for state in states):\n",
    "        return None\n",
    "    detached = tuple(state.detach() for state in states)\n",
    "    return detached if isinstance(hidden, tuple) else detached[0]\n",
    "\n",
    "\n",
    "def training(\n",
    "    model,\n",
    "    train_loader,\n",
    "    epoch,\n",
    "    loss_fct,\n",
    "    optimizer,\n",
    "    save_ckpt_callback=None,\n",
    "    stateful=False,  # stateful training needs contiguous, unshuffled batches\n",
    "    ):\n",
    "    \"\"\"Train model for `epoch` passes over train_loader and record the loss of every step.\n",
    "\n",
    "    Uses the notebook-level `device` for batch placement and `tqdm` for the progress bar.\n",
    "\n",
    "    Args:\n",
    "        model: module called as model(datas, hidden=...) and returning (logits, hidden).\n",
    "        train_loader: iterable of (datas, labels) batches that supports len().\n",
    "        epoch (int): number of passes over train_loader.\n",
    "        loss_fct: callable taking (flattened logits, flattened labels) and returning a scalar loss.\n",
    "        optimizer: optimizer that updates the model parameters.\n",
    "        save_ckpt_callback: optional callable invoked every step as\n",
    "            callback(step, state_dict, metric=-loss).\n",
    "        stateful (bool): if True, the hidden state of the previous batch (detached) seeds\n",
    "            the next batch within an epoch; if False, every batch starts from a fresh state.\n",
    "\n",
    "    Returns:\n",
    "        dict: {\"train\": [{\"loss\": float, \"step\": int}, ...]} with one entry per step.\n",
    "    \"\"\"\n",
    "    record_dict = {\n",
    "        \"train\": [],\n",
    "    }\n",
    "\n",
    "    global_step = 0\n",
    "    model.train()\n",
    "    with tqdm(total=epoch * len(train_loader)) as pbar:\n",
    "        for epoch_id in range(epoch):\n",
    "            # Each epoch restarts at the beginning of the data, so no state carries over.\n",
    "            hidden = None\n",
    "            for datas, labels in train_loader:\n",
    "                datas = datas.to(device)\n",
    "                labels = labels.to(device)\n",
    "                # Clear the gradients left over from the previous step.\n",
    "                optimizer.zero_grad()\n",
    "                # Shuffled data (stateful=False): start every batch from a fresh hidden state.\n",
    "                # Contiguous data (stateful=True): continue from the previous batch's state.\n",
    "                initial_hidden = _carry_hidden(hidden, datas.size(0)) if stateful else None\n",
    "                logits, hidden = model(datas, hidden=initial_hidden)\n",
    "                # Flatten batch and time so the loss sees (N, C) logits and (N,) targets,\n",
    "                # the layout cross-entropy expects.\n",
    "                loss = loss_fct(logits.reshape(-1, logits.size(-1)), labels.reshape(-1))\n",
    "                # Backpropagate, then let the optimizer update the parameters.\n",
    "                loss.backward()\n",
    "                optimizer.step()\n",
    "\n",
    "                loss_value = loss.item()\n",
    "                record_dict[\"train\"].append({\n",
    "                    \"loss\": loss_value, \"step\": global_step\n",
    "                })\n",
    "\n",
    "                # Save model weights; the callback treats higher metrics as better, hence -loss.\n",
    "                if save_ckpt_callback is not None:\n",
    "                    save_ckpt_callback(global_step, model.state_dict(), metric=-loss_value)\n",
    "                # Advance the step counter and the progress bar.\n",
    "                global_step += 1\n",
    "                pbar.update(1)\n",
    "                pbar.set_postfix({\"epoch\": epoch_id})\n",
    "\n",
    "    return record_dict\n",
    "        \n",
    "\n",
    "epoch = 100\n",
    "\n",
    "model = CharRNN(vocab_size=vocab_size)\n",
    "\n",
    "# 1. Loss function: cross entropy.\n",
    "loss_fct = nn.CrossEntropyLoss()\n",
    "# 2. Optimizer: Adam, from the torch.optim package.\n",
    "optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)\n",
    "\n",
    "# Checkpointing: keep only the best weights, checked every 1000 steps.\n",
    "# Make sure the parent directory exists before the callback creates its sub-directory.\n",
    "if not os.path.exists(\"checkpoints\"):\n",
    "    os.makedirs(\"checkpoints\")\n",
    "save_ckpt_callback = SaveCheckpointsCallback(\n",
    "    save_dir=\"checkpoints/text_generation\",\n",
    "    save_step=1000,\n",
    "    save_best_only=True,\n",
    ")\n",
    "\n",
    "# Move the model to the selected device before training starts.\n",
    "model = model.to(device)\n"
   ],
   "outputs": [],
   "execution_count": 31
  },
  {
   "cell_type": "code",
   "metadata": {
    "collapsed": false,
    "execution": {
     "iopub.execute_input": "2025-01-23T08:17:09.445822Z",
     "iopub.status.busy": "2025-01-23T08:17:09.445359Z",
     "iopub.status.idle": "2025-01-23T08:19:08.356217Z",
     "shell.execute_reply": "2025-01-23T08:19:08.355733Z",
     "shell.execute_reply.started": "2025-01-23T08:17:09.445799Z"
    },
    "jupyter": {
     "outputs_hidden": false
    },
    "ExecuteTime": {
     "end_time": "2025-03-01T09:08:44.806937Z",
     "start_time": "2025-03-01T09:06:14.075147Z"
    }
   },
   "source": [
    "# Run the training loop; record collects the loss of every training step.\n",
    "record = training(\n",
    "    model=model,\n",
    "    train_loader=train_dl,\n",
    "    epoch=epoch,\n",
    "    loss_fct=loss_fct,\n",
    "    optimizer=optimizer,\n",
    "    save_ckpt_callback=save_ckpt_callback,\n",
    ")"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "  0%|          | 0/17300 [00:00<?, ?it/s]"
      ],
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "f164394b1377465abce96c49a06be98b"
      }
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001B[1;31m---------------------------------------------------------------------------\u001B[0m",
      "\u001B[1;31mKeyboardInterrupt\u001B[0m                         Traceback (most recent call last)",
      "Cell \u001B[1;32mIn[32], line 1\u001B[0m\n\u001B[1;32m----> 1\u001B[0m record \u001B[38;5;241m=\u001B[39m \u001B[43mtraining\u001B[49m\u001B[43m(\u001B[49m\n\u001B[0;32m      2\u001B[0m \u001B[43m    \u001B[49m\u001B[43mmodel\u001B[49m\u001B[43m,\u001B[49m\n\u001B[0;32m      3\u001B[0m \u001B[43m    \u001B[49m\u001B[43mtrain_dl\u001B[49m\u001B[43m,\u001B[49m\n\u001B[0;32m      4\u001B[0m \u001B[43m    \u001B[49m\u001B[43mepoch\u001B[49m\u001B[43m,\u001B[49m\n\u001B[0;32m      5\u001B[0m \u001B[43m    \u001B[49m\u001B[43mloss_fct\u001B[49m\u001B[43m,\u001B[49m\n\u001B[0;32m      6\u001B[0m \u001B[43m    \u001B[49m\u001B[43moptimizer\u001B[49m\u001B[43m,\u001B[49m\n\u001B[0;32m      7\u001B[0m \u001B[43m    \u001B[49m\u001B[43msave_ckpt_callback\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43msave_ckpt_callback\u001B[49m\u001B[43m,\u001B[49m\n\u001B[0;32m      8\u001B[0m \u001B[43m    \u001B[49m\u001B[43m)\u001B[49m\n",
      "Cell \u001B[1;32mIn[31], line 34\u001B[0m, in \u001B[0;36mtraining\u001B[1;34m(model, train_loader, epoch, loss_fct, optimizer, save_ckpt_callback, stateful)\u001B[0m\n\u001B[0;32m     32\u001B[0m loss\u001B[38;5;241m.\u001B[39mbackward()\n\u001B[0;32m     33\u001B[0m \u001B[38;5;66;03m# 调整优化器，包括学习率的变动等\u001B[39;00m\n\u001B[1;32m---> 34\u001B[0m \u001B[43moptimizer\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mstep\u001B[49m\u001B[43m(\u001B[49m\u001B[43m)\u001B[49m\n\u001B[0;32m     36\u001B[0m loss \u001B[38;5;241m=\u001B[39m loss\u001B[38;5;241m.\u001B[39mcpu()\u001B[38;5;241m.\u001B[39mitem()\n\u001B[0;32m     37\u001B[0m \u001B[38;5;66;03m# record\u001B[39;00m\n",
      "File \u001B[1;32m~\\AppData\\Roaming\\Python\\Python312\\site-packages\\torch\\optim\\optimizer.py:493\u001B[0m, in \u001B[0;36mOptimizer.profile_hook_step.<locals>.wrapper\u001B[1;34m(*args, **kwargs)\u001B[0m\n\u001B[0;32m    488\u001B[0m         \u001B[38;5;28;01melse\u001B[39;00m:\n\u001B[0;32m    489\u001B[0m             \u001B[38;5;28;01mraise\u001B[39;00m \u001B[38;5;167;01mRuntimeError\u001B[39;00m(\n\u001B[0;32m    490\u001B[0m                 \u001B[38;5;124mf\u001B[39m\u001B[38;5;124m\"\u001B[39m\u001B[38;5;132;01m{\u001B[39;00mfunc\u001B[38;5;132;01m}\u001B[39;00m\u001B[38;5;124m must return None or a tuple of (new_args, new_kwargs), but got \u001B[39m\u001B[38;5;132;01m{\u001B[39;00mresult\u001B[38;5;132;01m}\u001B[39;00m\u001B[38;5;124m.\u001B[39m\u001B[38;5;124m\"\u001B[39m\n\u001B[0;32m    491\u001B[0m             )\n\u001B[1;32m--> 493\u001B[0m out \u001B[38;5;241m=\u001B[39m \u001B[43mfunc\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43margs\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n\u001B[0;32m    494\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_optimizer_step_code()\n\u001B[0;32m    496\u001B[0m \u001B[38;5;66;03m# call optimizer step post hooks\u001B[39;00m\n",
      "File \u001B[1;32m~\\AppData\\Roaming\\Python\\Python312\\site-packages\\torch\\optim\\optimizer.py:91\u001B[0m, in \u001B[0;36m_use_grad_for_differentiable.<locals>._use_grad\u001B[1;34m(self, *args, **kwargs)\u001B[0m\n\u001B[0;32m     89\u001B[0m     torch\u001B[38;5;241m.\u001B[39mset_grad_enabled(\u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mdefaults[\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mdifferentiable\u001B[39m\u001B[38;5;124m\"\u001B[39m])\n\u001B[0;32m     90\u001B[0m     torch\u001B[38;5;241m.\u001B[39m_dynamo\u001B[38;5;241m.\u001B[39mgraph_break()\n\u001B[1;32m---> 91\u001B[0m     ret \u001B[38;5;241m=\u001B[39m \u001B[43mfunc\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43margs\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n\u001B[0;32m     92\u001B[0m \u001B[38;5;28;01mfinally\u001B[39;00m:\n\u001B[0;32m     93\u001B[0m     torch\u001B[38;5;241m.\u001B[39m_dynamo\u001B[38;5;241m.\u001B[39mgraph_break()\n",
      "File \u001B[1;32m~\\AppData\\Roaming\\Python\\Python312\\site-packages\\torch\\optim\\adam.py:244\u001B[0m, in \u001B[0;36mAdam.step\u001B[1;34m(self, closure)\u001B[0m\n\u001B[0;32m    232\u001B[0m     beta1, beta2 \u001B[38;5;241m=\u001B[39m group[\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mbetas\u001B[39m\u001B[38;5;124m\"\u001B[39m]\n\u001B[0;32m    234\u001B[0m     has_complex \u001B[38;5;241m=\u001B[39m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_init_group(\n\u001B[0;32m    235\u001B[0m         group,\n\u001B[0;32m    236\u001B[0m         params_with_grad,\n\u001B[1;32m   (...)\u001B[0m\n\u001B[0;32m    241\u001B[0m         state_steps,\n\u001B[0;32m    242\u001B[0m     )\n\u001B[1;32m--> 244\u001B[0m     \u001B[43madam\u001B[49m\u001B[43m(\u001B[49m\n\u001B[0;32m    245\u001B[0m \u001B[43m        \u001B[49m\u001B[43mparams_with_grad\u001B[49m\u001B[43m,\u001B[49m\n\u001B[0;32m    246\u001B[0m \u001B[43m        \u001B[49m\u001B[43mgrads\u001B[49m\u001B[43m,\u001B[49m\n\u001B[0;32m    247\u001B[0m \u001B[43m        \u001B[49m\u001B[43mexp_avgs\u001B[49m\u001B[43m,\u001B[49m\n\u001B[0;32m    248\u001B[0m \u001B[43m        \u001B[49m\u001B[43mexp_avg_sqs\u001B[49m\u001B[43m,\u001B[49m\n\u001B[0;32m    249\u001B[0m \u001B[43m        \u001B[49m\u001B[43mmax_exp_avg_sqs\u001B[49m\u001B[43m,\u001B[49m\n\u001B[0;32m    250\u001B[0m \u001B[43m        \u001B[49m\u001B[43mstate_steps\u001B[49m\u001B[43m,\u001B[49m\n\u001B[0;32m    251\u001B[0m \u001B[43m        \u001B[49m\u001B[43mamsgrad\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mgroup\u001B[49m\u001B[43m[\u001B[49m\u001B[38;5;124;43m\"\u001B[39;49m\u001B[38;5;124;43mamsgrad\u001B[39;49m\u001B[38;5;124;43m\"\u001B[39;49m\u001B[43m]\u001B[49m\u001B[43m,\u001B[49m\n\u001B[0;32m    252\u001B[0m \u001B[43m        \u001B[49m\u001B[43mhas_complex\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mhas_complex\u001B[49m\u001B[43m,\u001B[49m\n\u001B[0;32m    253\u001B[0m \u001B[43m        
\u001B[49m\u001B[43mbeta1\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mbeta1\u001B[49m\u001B[43m,\u001B[49m\n\u001B[0;32m    254\u001B[0m \u001B[43m        \u001B[49m\u001B[43mbeta2\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mbeta2\u001B[49m\u001B[43m,\u001B[49m\n\u001B[0;32m    255\u001B[0m \u001B[43m        \u001B[49m\u001B[43mlr\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mgroup\u001B[49m\u001B[43m[\u001B[49m\u001B[38;5;124;43m\"\u001B[39;49m\u001B[38;5;124;43mlr\u001B[39;49m\u001B[38;5;124;43m\"\u001B[39;49m\u001B[43m]\u001B[49m\u001B[43m,\u001B[49m\n\u001B[0;32m    256\u001B[0m \u001B[43m        \u001B[49m\u001B[43mweight_decay\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mgroup\u001B[49m\u001B[43m[\u001B[49m\u001B[38;5;124;43m\"\u001B[39;49m\u001B[38;5;124;43mweight_decay\u001B[39;49m\u001B[38;5;124;43m\"\u001B[39;49m\u001B[43m]\u001B[49m\u001B[43m,\u001B[49m\n\u001B[0;32m    257\u001B[0m \u001B[43m        \u001B[49m\u001B[43meps\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mgroup\u001B[49m\u001B[43m[\u001B[49m\u001B[38;5;124;43m\"\u001B[39;49m\u001B[38;5;124;43meps\u001B[39;49m\u001B[38;5;124;43m\"\u001B[39;49m\u001B[43m]\u001B[49m\u001B[43m,\u001B[49m\n\u001B[0;32m    258\u001B[0m \u001B[43m        \u001B[49m\u001B[43mmaximize\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mgroup\u001B[49m\u001B[43m[\u001B[49m\u001B[38;5;124;43m\"\u001B[39;49m\u001B[38;5;124;43mmaximize\u001B[39;49m\u001B[38;5;124;43m\"\u001B[39;49m\u001B[43m]\u001B[49m\u001B[43m,\u001B[49m\n\u001B[0;32m    259\u001B[0m \u001B[43m        \u001B[49m\u001B[43mforeach\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mgroup\u001B[49m\u001B[43m[\u001B[49m\u001B[38;5;124;43m\"\u001B[39;49m\u001B[38;5;124;43mforeach\u001B[39;49m\u001B[38;5;124;43m\"\u001B[39;49m\u001B[43m]\u001B[49m\u001B[43m,\u001B[49m\n\u001B[0;32m    260\u001B[0m \u001B[43m        
\u001B[49m\u001B[43mcapturable\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mgroup\u001B[49m\u001B[43m[\u001B[49m\u001B[38;5;124;43m\"\u001B[39;49m\u001B[38;5;124;43mcapturable\u001B[39;49m\u001B[38;5;124;43m\"\u001B[39;49m\u001B[43m]\u001B[49m\u001B[43m,\u001B[49m\n\u001B[0;32m    261\u001B[0m \u001B[43m        \u001B[49m\u001B[43mdifferentiable\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mgroup\u001B[49m\u001B[43m[\u001B[49m\u001B[38;5;124;43m\"\u001B[39;49m\u001B[38;5;124;43mdifferentiable\u001B[39;49m\u001B[38;5;124;43m\"\u001B[39;49m\u001B[43m]\u001B[49m\u001B[43m,\u001B[49m\n\u001B[0;32m    262\u001B[0m \u001B[43m        \u001B[49m\u001B[43mfused\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mgroup\u001B[49m\u001B[43m[\u001B[49m\u001B[38;5;124;43m\"\u001B[39;49m\u001B[38;5;124;43mfused\u001B[39;49m\u001B[38;5;124;43m\"\u001B[39;49m\u001B[43m]\u001B[49m\u001B[43m,\u001B[49m\n\u001B[0;32m    263\u001B[0m \u001B[43m        \u001B[49m\u001B[43mgrad_scale\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;28;43mgetattr\u001B[39;49m\u001B[43m(\u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;124;43m\"\u001B[39;49m\u001B[38;5;124;43mgrad_scale\u001B[39;49m\u001B[38;5;124;43m\"\u001B[39;49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;28;43;01mNone\u001B[39;49;00m\u001B[43m)\u001B[49m\u001B[43m,\u001B[49m\n\u001B[0;32m    264\u001B[0m \u001B[43m        \u001B[49m\u001B[43mfound_inf\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;28;43mgetattr\u001B[39;49m\u001B[43m(\u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;124;43m\"\u001B[39;49m\u001B[38;5;124;43mfound_inf\u001B[39;49m\u001B[38;5;124;43m\"\u001B[39;49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;28;43;01mNone\u001B[39;49;00m\u001B[43m)\u001B[49m\u001B[43m,\u001B[49m\n\u001B[0;32m    265\u001B[0m \u001B[43m    \u001B[49m\u001B[43m)\u001B[49m\n\u001B[0;32m    267\u001B[0m 
\u001B[38;5;28;01mreturn\u001B[39;00m loss\n",
      "File \u001B[1;32m~\\AppData\\Roaming\\Python\\Python312\\site-packages\\torch\\optim\\optimizer.py:154\u001B[0m, in \u001B[0;36m_disable_dynamo_if_unsupported.<locals>.wrapper.<locals>.maybe_fallback\u001B[1;34m(*args, **kwargs)\u001B[0m\n\u001B[0;32m    152\u001B[0m     \u001B[38;5;28;01mreturn\u001B[39;00m disabled_func(\u001B[38;5;241m*\u001B[39margs, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs)\n\u001B[0;32m    153\u001B[0m \u001B[38;5;28;01melse\u001B[39;00m:\n\u001B[1;32m--> 154\u001B[0m     \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[43mfunc\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43margs\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n",
      "File \u001B[1;32m~\\AppData\\Roaming\\Python\\Python312\\site-packages\\torch\\optim\\adam.py:876\u001B[0m, in \u001B[0;36madam\u001B[1;34m(params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, foreach, capturable, differentiable, fused, grad_scale, found_inf, has_complex, amsgrad, beta1, beta2, lr, weight_decay, eps, maximize)\u001B[0m\n\u001B[0;32m    873\u001B[0m \u001B[38;5;28;01melse\u001B[39;00m:\n\u001B[0;32m    874\u001B[0m     func \u001B[38;5;241m=\u001B[39m _single_tensor_adam\n\u001B[1;32m--> 876\u001B[0m \u001B[43mfunc\u001B[49m\u001B[43m(\u001B[49m\n\u001B[0;32m    877\u001B[0m \u001B[43m    \u001B[49m\u001B[43mparams\u001B[49m\u001B[43m,\u001B[49m\n\u001B[0;32m    878\u001B[0m \u001B[43m    \u001B[49m\u001B[43mgrads\u001B[49m\u001B[43m,\u001B[49m\n\u001B[0;32m    879\u001B[0m \u001B[43m    \u001B[49m\u001B[43mexp_avgs\u001B[49m\u001B[43m,\u001B[49m\n\u001B[0;32m    880\u001B[0m \u001B[43m    \u001B[49m\u001B[43mexp_avg_sqs\u001B[49m\u001B[43m,\u001B[49m\n\u001B[0;32m    881\u001B[0m \u001B[43m    \u001B[49m\u001B[43mmax_exp_avg_sqs\u001B[49m\u001B[43m,\u001B[49m\n\u001B[0;32m    882\u001B[0m \u001B[43m    \u001B[49m\u001B[43mstate_steps\u001B[49m\u001B[43m,\u001B[49m\n\u001B[0;32m    883\u001B[0m \u001B[43m    \u001B[49m\u001B[43mamsgrad\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mamsgrad\u001B[49m\u001B[43m,\u001B[49m\n\u001B[0;32m    884\u001B[0m \u001B[43m    \u001B[49m\u001B[43mhas_complex\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mhas_complex\u001B[49m\u001B[43m,\u001B[49m\n\u001B[0;32m    885\u001B[0m \u001B[43m    \u001B[49m\u001B[43mbeta1\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mbeta1\u001B[49m\u001B[43m,\u001B[49m\n\u001B[0;32m    886\u001B[0m \u001B[43m    \u001B[49m\u001B[43mbeta2\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mbeta2\u001B[49m\u001B[43m,\u001B[49m\n\u001B[0;32m    887\u001B[0m \u001B[43m    
\u001B[49m\u001B[43mlr\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mlr\u001B[49m\u001B[43m,\u001B[49m\n\u001B[0;32m    888\u001B[0m \u001B[43m    \u001B[49m\u001B[43mweight_decay\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mweight_decay\u001B[49m\u001B[43m,\u001B[49m\n\u001B[0;32m    889\u001B[0m \u001B[43m    \u001B[49m\u001B[43meps\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43meps\u001B[49m\u001B[43m,\u001B[49m\n\u001B[0;32m    890\u001B[0m \u001B[43m    \u001B[49m\u001B[43mmaximize\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mmaximize\u001B[49m\u001B[43m,\u001B[49m\n\u001B[0;32m    891\u001B[0m \u001B[43m    \u001B[49m\u001B[43mcapturable\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mcapturable\u001B[49m\u001B[43m,\u001B[49m\n\u001B[0;32m    892\u001B[0m \u001B[43m    \u001B[49m\u001B[43mdifferentiable\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mdifferentiable\u001B[49m\u001B[43m,\u001B[49m\n\u001B[0;32m    893\u001B[0m \u001B[43m    \u001B[49m\u001B[43mgrad_scale\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mgrad_scale\u001B[49m\u001B[43m,\u001B[49m\n\u001B[0;32m    894\u001B[0m \u001B[43m    \u001B[49m\u001B[43mfound_inf\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mfound_inf\u001B[49m\u001B[43m,\u001B[49m\n\u001B[0;32m    895\u001B[0m \u001B[43m\u001B[49m\u001B[43m)\u001B[49m\n",
      "File \u001B[1;32m~\\AppData\\Roaming\\Python\\Python312\\site-packages\\torch\\optim\\adam.py:601\u001B[0m, in \u001B[0;36m_multi_tensor_adam\u001B[1;34m(params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, grad_scale, found_inf, amsgrad, has_complex, beta1, beta2, lr, weight_decay, eps, maximize, capturable, differentiable)\u001B[0m\n\u001B[0;32m    596\u001B[0m \u001B[38;5;66;03m# Update steps\u001B[39;00m\n\u001B[0;32m    597\u001B[0m \u001B[38;5;66;03m# If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over\u001B[39;00m\n\u001B[0;32m    598\u001B[0m \u001B[38;5;66;03m# and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just\u001B[39;00m\n\u001B[0;32m    599\u001B[0m \u001B[38;5;66;03m# wrapped it once now. The alpha is required to assure we go to the right overload.\u001B[39;00m\n\u001B[0;32m    600\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m torch\u001B[38;5;241m.\u001B[39mcompiler\u001B[38;5;241m.\u001B[39mis_compiling() \u001B[38;5;129;01mand\u001B[39;00m device_state_steps[\u001B[38;5;241m0\u001B[39m]\u001B[38;5;241m.\u001B[39mis_cpu:\n\u001B[1;32m--> 601\u001B[0m     \u001B[43mtorch\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43m_foreach_add_\u001B[49m\u001B[43m(\u001B[49m\n\u001B[0;32m    602\u001B[0m \u001B[43m        \u001B[49m\u001B[43mdevice_state_steps\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mtorch\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mtensor\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m1.0\u001B[39;49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mdevice\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;124;43m\"\u001B[39;49m\u001B[38;5;124;43mcpu\u001B[39;49m\u001B[38;5;124;43m\"\u001B[39;49m\u001B[43m)\u001B[49m\u001B[43m,\u001B[49m\u001B[43m 
\u001B[49m\u001B[43malpha\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;241;43m1.0\u001B[39;49m\n\u001B[0;32m    603\u001B[0m \u001B[43m    \u001B[49m\u001B[43m)\u001B[49m\n\u001B[0;32m    604\u001B[0m \u001B[38;5;28;01melse\u001B[39;00m:\n\u001B[0;32m    605\u001B[0m     torch\u001B[38;5;241m.\u001B[39m_foreach_add_(device_state_steps, \u001B[38;5;241m1\u001B[39m)\n",
      "\u001B[1;31mKeyboardInterrupt\u001B[0m: "
     ]
    }
   ],
   "execution_count": 32
  },
  {
   "cell_type": "code",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-01-23T08:19:08.357145Z",
     "iopub.status.busy": "2025-01-23T08:19:08.356810Z",
     "iopub.status.idle": "2025-01-23T08:19:08.455159Z",
     "shell.execute_reply": "2025-01-23T08:19:08.454531Z",
     "shell.execute_reply.started": "2025-01-23T08:19:08.357124Z"
    },
    "ExecuteTime": {
     "end_time": "2025-03-01T09:08:44.807934Z",
     "start_time": "2025-03-01T09:08:44.807934Z"
    }
   },
   "source": [
    "plt.plot([i[\"step\"] for i in record[\"train\"][::50]], [i[\"loss\"] for i in record[\"train\"][::50]], label=\"train\")\n",
    "plt.grid()\n",
    "plt.show()"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 推理"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "collapsed": false,
    "execution": {
     "iopub.execute_input": "2025-01-23T08:19:08.456317Z",
     "iopub.status.busy": "2025-01-23T08:19:08.455886Z",
     "iopub.status.idle": "2025-01-23T08:19:08.467402Z",
     "shell.execute_reply": "2025-01-23T08:19:08.466888Z",
     "shell.execute_reply.started": "2025-01-23T08:19:08.456284Z"
    },
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "source": [
    "#下面的例子是为了说明temperature\n",
    "my_tensor = torch.tensor([0.4,0.6]) #这里是logits\n",
    "\n",
    "probs1 = F.softmax(my_tensor, dim=-1)\n",
    "print(probs1)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {
    "collapsed": false,
    "execution": {
     "iopub.execute_input": "2025-01-23T08:19:08.468453Z",
     "iopub.status.busy": "2025-01-23T08:19:08.468007Z",
     "iopub.status.idle": "2025-01-23T08:19:08.472351Z",
     "shell.execute_reply": "2025-01-23T08:19:08.471801Z",
     "shell.execute_reply.started": "2025-01-23T08:19:08.468420Z"
    },
    "jupyter": {
     "outputs_hidden": false
    },
    "ExecuteTime": {
     "end_time": "2025-03-01T09:08:44.808935Z",
     "start_time": "2025-03-01T09:08:44.808935Z"
    }
   },
   "source": [
    "my_tensor = torch.tensor([0.2,0.3])  #现在 temperature是2\n",
    "\n",
    "probs1 = F.softmax(my_tensor, dim=-1)\n",
    "print(probs1)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {
    "collapsed": false,
    "execution": {
     "iopub.execute_input": "2025-01-23T08:19:08.473444Z",
     "iopub.status.busy": "2025-01-23T08:19:08.473043Z",
     "iopub.status.idle": "2025-01-23T08:19:08.486270Z",
     "shell.execute_reply": "2025-01-23T08:19:08.485745Z",
     "shell.execute_reply.started": "2025-01-23T08:19:08.473414Z"
    },
    "jupyter": {
     "outputs_hidden": false
    },
    "ExecuteTime": {
     "end_time": "2025-03-01T09:08:44.809935Z",
     "start_time": "2025-03-01T09:08:44.809935Z"
    }
   },
   "source": [
    "import torch\n",
    "\n",
    "# 创建一个概率分布，表示每个类别被选中的概率\n",
    "# 这里我们有一个简单的四个类别的概率分布\n",
    "prob_dist = torch.tensor([0.1, 0.45, 0.35, 0.1])\n",
    "\n",
    "# 使用 multinomial 进行抽样\n",
    "# num_samples 表示要抽取的样本数量\n",
    "num_samples = 5\n",
    "\n",
    "# 抽取样本，随机抽样，概率越高，抽到的概率就越高\n",
    "samples = torch.multinomial(prob_dist, 1, replacement=True)\n",
    "\n",
    "print(\"概率分布:\", prob_dist)\n",
    "print(\"抽取的样本索引:\", samples)\n",
    "\n",
    "# 显示每个样本对应的概率\n",
    "print(\"每个样本对应的概率:\", prob_dist[samples])"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-01-23T08:19:08.487270Z",
     "iopub.status.busy": "2025-01-23T08:19:08.486885Z",
     "iopub.status.idle": "2025-01-23T08:19:09.442765Z",
     "shell.execute_reply": "2025-01-23T08:19:09.442186Z",
     "shell.execute_reply.started": "2025-01-23T08:19:08.487239Z"
    }
   },
   "source": [
    "def generate_text(model, start_string, max_len=1000, temperature=1.0, stream=True):\n",
    "    input_eval = torch.Tensor([char2idx[char] for char in start_string]).to(dtype=torch.int64, device=device).reshape(1, -1) #bacth_size=1, seq_len长度是多少都可以 （1,5）\n",
    "    hidden = None\n",
    "    text_generated = [] #用来保存生成的文本\n",
    "    model.eval()\n",
    "    pbar = tqdm(range(max_len)) # 进度条\n",
    "    print(start_string, end=\"\")\n",
    "    # no_grad是一个上下文管理器，用于指定在其中的代码块中不需要计算梯度。在这个区域内，不会记录梯度信息，用于在生成文本时不影响模型权重。\n",
    "    with torch.no_grad():\n",
    "        for i in pbar:#控制进度条\n",
    "            logits, hidden = model(input_eval, hidden=hidden)\n",
    "            # 温度采样，较高的温度会增加预测结果的多样性，较低的温度则更加保守。\n",
    "            #取-1的目的是只要最后，拼到原有的输入上\n",
    "            logits = logits[0, -1, :] / temperature\n",
    "            # using multinomial to sampling\n",
    "            probs = F.softmax(logits, dim=-1) #算为概率分布\n",
    "            idx = torch.multinomial(probs, 1).item() #从概率分布中抽取一个样本,取概率较大的那些\n",
    "            input_eval = torch.Tensor([idx]).to(dtype=torch.int64, device=device).reshape(1, -1) #把idx转为tensor\n",
    "            text_generated.append(idx)\n",
    "            if stream:\n",
    "                print(idx2char[idx], end=\"\", flush=True)\n",
    "    return \"\".join([idx2char[i] for i in text_generated])\n",
    "\n",
    "\n",
    "# load checkpoints\n",
    "model.load_state_dict(torch.load(\"checkpoints/text_generation/best.ckpt\", map_location=\"cpu\"))\n",
    "start_string = \"All: \" #这里就是开头，什么都可以\n",
    "res = generate_text(model, start_string, max_len=1000, temperature=0.5, stream=True)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-01-23T08:22:23.716047Z",
     "iopub.status.busy": "2025-01-23T08:22:23.715677Z",
     "iopub.status.idle": "2025-01-23T08:22:23.720173Z",
     "shell.execute_reply": "2025-01-23T08:22:23.719579Z",
     "shell.execute_reply.started": "2025-01-23T08:22:23.716020Z"
    },
    "tags": []
   },
   "source": [
    "res"
   ],
   "outputs": [],
   "execution_count": null
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
